Stridedslicegrad

计算 StridedSlice(步长切片)操作的梯度。该算子是 StridedSlice 算子的反向传播(backward pass)部分。

该算子将上游传来的梯度(对应 StridedSlice 输出的形状)映射回原始输入的位置(对应 StridedSlice 输入的形状)。对于输入梯度中的每个元素,根据原始 StridedSlice 操作的 beginsstrides 参数,计算其在原始输入中的位置,并将梯度值写入该位置。

\[\text{output}[\text{idx}] = \text{inputs}[\text{pos}]\]

其中 idx 是根据 posin_shape 中的多维索引、stridesbegins 计算出的在 dx_shape 中的线性索引。

输入:
  • inputs - 上游传来的梯度张量数据地址(即 \(dy\)),形状为 :math:`in_shape`(原始 StridedSlice 操作的输出形状)。

  • dx_shape - 输出梯度张量的形状数组(int*),大小为8,对应原始 StridedSlice 操作的输入形状。对于维度小于8的张量,高位维度形状为1。

  • strides - 原始 StridedSlice 操作的步长数组(int*),大小为8。对于维度小于8的张量,高位维度步长为1。

  • begins - 原始 StridedSlice 操作的起始索引数组(int*),大小为8。对于维度小于8的张量,高位维度起始索引为0。

  • in_shape - 原始 StridedSlice 操作的输出形状数组(int*),大小为8,即输入梯度 inputs 的形状。对于维度小于8的张量,高位维度形状为1。

  • core_mask - 核掩码(int),仅共享存储版本需要。

输出:
  • output - 输出梯度张量数据地址(即 \(dx\)),形状为 :math:`dx_shape`(原始 StridedSlice 操作的输入形状)。该张量在调用前通常被初始化为全零。

支持平台:

FT78NE MT7004

备注

  • MT7004 支持fp16, fp32

  • FT78NE 支持fp32

  • 输出张量 output 在调用前需要预先初始化为全零

  • 形状数组固定为8维,对于维度小于8的张量,高位维度形状为1

共享存储版本:

void hp_stridedslicegrad_s(half *inputs, half *output, int *dx_shape, int *strides, int *begins, int *in_shape, int core_mask)
void fp_stridedslicegrad_s(float *inputs, float *output, int *dx_shape, int *strides, int *begins, int *in_shape, int core_mask)

C调用示例:

 1//MT7004示例
 2#include <stdio.h>
 3#include <stridedslicegrad.h>
 4
 5int main(int argc, char* argv[]) {
 6    // 假设在DDR空间
 7    // 原始 StridedSlice 操作:
 8    // 输入形状 [2, 3, 4, 5]
 9    // 输出形状 [1, 2, 2, 3]
10    // begins = [0, 1, 1, 2], strides = [1, 1, 2, 1]
11
12    // 输出梯度形状(原始输入形状)
13    int dx_shape[8] = {2, 3, 4, 5, 1, 1, 1, 1};
14
15    // 输入梯度形状(原始输出形状)
16    int in_shape[8] = {1, 2, 2, 3, 1, 1, 1, 1};
17
18    // 原始 StridedSlice 参数
19    int begins[8] = {0, 1, 1, 2, 0, 0, 0, 0};
20    int strides[8] = {1, 1, 2, 1, 1, 1, 1, 1};
21
22    // 输入梯度(上游传来的梯度)
23    float *inputs = (float *)0xA0000000;  // 形状为 in_shape
24    // inputs 包含 1 * 2 * 2 * 3 = 12 个元素
25
26    // 输出梯度(待计算)
27    float *output = (float *)0xB0000000;  // 形状为 dx_shape
28    // output 包含 2 * 3 * 4 * 5 = 120 个元素
29
30    // 初始化输出为全零
31    memset(output, 0, 120 * sizeof(float));
32
33    int core_mask = 0xff;
34
35    fp_stridedslicegrad_s(inputs, output, dx_shape, strides, begins, in_shape, core_mask);
36
37    return 0;
38}

私有存储版本:

void hp_stridedslicegrad_p(half *inputs, half *output, int *dx_shape, int *strides, int *begins, int *in_shape)
void fp_stridedslicegrad_p(float *inputs, float *output, int *dx_shape, int *strides, int *begins, int *in_shape)

C调用示例:

 1//MT7004示例
 2#include <stdio.h>
 3#include <stridedslicegrad.h>
 4
 5int main(int argc, char* argv[]) {
 6    // 假设在L2空间
 7    int dx_shape[8] = {2, 3, 4, 5, 1, 1, 1, 1};
 8    int in_shape[8] = {1, 2, 2, 3, 1, 1, 1, 1};
 9
10    int begins[8] = {0, 1, 1, 2, 0, 0, 0, 0};
11    int strides[8] = {1, 1, 2, 1, 1, 1, 1, 1};
12
13    float *inputs = (float *)0x10000000;
14    float *output = (float *)0x10001000;
15
16    // 初始化输出为全零
17    memset(output, 0, 120 * sizeof(float));
18
19    fp_stridedslicegrad_p(inputs, output, dx_shape, strides, begins, in_shape);
20
21    return 0;
22}